// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;

import java.util.Set;

/**
 * Push the predicate down through generate.
 */
public class PushDownFilterThroughGenerate extends OneRewriteRuleFactory {
    public static final PushDownFilterThroughGenerate INSTANCE = new PushDownFilterThroughGenerate();

    /**
     * Build the rule that moves filter conjuncts below a {@link LogicalGenerate}.
     *
     * <p>A conjunct is pushed below the generate when it references at least one slot and every
     * slot it references is an output of the generate's child. All other conjuncts stay above:
     * those referencing any slot not produced by the child, and those referencing no slot at all.
     *
     * <pre>
     *   filter(pushable, remaining)          filter(remaining)   (dropped if remaining is empty)
     *              |                                |
     *          generate               ==>       generate
     *              |                                |
     *            child                        filter(pushable)
     *                                               |
     *                                             child
     * </pre>
     *
     * <p>When no conjunct is pushable, the lambda returns {@code null}, i.e. the rule does not
     * rewrite the plan.
     */
    @Override
    public Rule build() {
        return logicalFilter(logicalGenerate()).then(filter -> {
            LogicalGenerate<Plan> generate = filter.child();
            Plan generateChild = generate.child();
            Set<Slot> childOutputs = generateChild.getOutputSet();
            // Linked sets keep the conjuncts in the filter's own iteration order, so the
            // rewritten filters (and the resulting plan / explain output) are deterministic.
            // A plain HashSet would reorder them by hash code.
            Set<Expression> pushDownPredicates = Sets.newLinkedHashSet();
            Set<Expression> remainPredicates = Sets.newLinkedHashSet();
            filter.getConjuncts().forEach(conjunct -> {
                Set<Slot> conjunctSlots = conjunct.getInputSlots();
                // Slot-free conjuncts are intentionally kept above the generate.
                if (!conjunctSlots.isEmpty() && childOutputs.containsAll(conjunctSlots)) {
                    pushDownPredicates.add(conjunct);
                } else {
                    remainPredicates.add(conjunct);
                }
            });
            if (pushDownPredicates.isEmpty()) {
                // Nothing can be pushed; signal "no rewrite".
                return null;
            }
            Plan bottomFilter = new LogicalFilter<>(pushDownPredicates, generateChild);
            LogicalGenerate<Plan> newGenerate = generate.withChildren(ImmutableList.of(bottomFilter));
            // Re-apply the remaining conjuncts above the generate, or return it bare if none remain.
            return PlanUtils.filterOrSelf(remainPredicates, newGenerate);
        }).toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_GENERATE);
    }
}
